#!/usr/bin/env python
# coding: utf-8
# Created on Mon Oct. 24 15:24:18 2022
# @author: Lu Jian
# Email: janelu@live.cn

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class PositionalEmbedding(nn.Layer):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        # inverse frequencies 1 / 10000^(2i/d), shape [1, demb/2]
        inv_freq = (1.0 / (10000 ** (paddle.arange(0.0, demb, 2.0) / demb))).unsqueeze(0)
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        # pos_seq: [klen]; outer product with the inverse frequencies -> [klen, demb]
        sinusoid_inp = pos_seq.unsqueeze(1) * self.inv_freq
        pos_emb = paddle.concat([sinusoid_inp.sin(), sinusoid_inp.cos()], axis=-1)
        if bsz is not None:
            return pos_emb[None, :, :].expand([bsz, -1, -1])
        return pos_emb
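
# Usage sketch (our illustration; sizes arbitrary). Positions are fed in
# reversed order, matching ERNIE_XL._forward below:
#   pe = PositionalEmbedding(8)
#   pe(paddle.arange(3, -1, -1.0)).shape   # -> [4, 8]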
        
class MultiHeadAttention(nn.Layer):
    def __init__(self,embed_dim,num_heads,dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 weight_attr=None,bias_attr=None,**kwargs):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.n_head = num_heads
        self.dropatt = nn.Dropout(dropatt, mode="upscale_in_train")
        self.d_head = embed_dim // num_heads
        assert self.d_head * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.q_proj = nn.Linear(embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
        self.k_proj = nn.Linear(embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
        self.v_proj = nn.Linear(embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
        self.out_proj = nn.Linear(embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
        
        self.scale = self.d_head ** -0.5
    
    def compute_kv(self, key, value):
        # Split d_model into heads; k is pre-transposed to [bsz, n_head, d_head, klen]
        # so attention() can apply matmul(q, k) directly.
        k = paddle.transpose(paddle.reshape(self.k_proj(key), [0, 0, self.n_head, self.d_head]), [0, 2, 3, 1])
        v = paddle.transpose(paddle.reshape(self.v_proj(value), [0, 0, self.n_head, self.d_head]), [0, 2, 1, 3])
        return k, v

    def attention(self, q, k, v, attn_mask=0):
        # q: [bsz, n_head, qlen, d_head], k: [bsz, n_head, d_head, klen]
        product = paddle.matmul(q, k) * self.scale + attn_mask
        weights = self.dropatt(F.softmax(product))
        # merge heads back: [bsz, n_head, qlen, d_head] -> [bsz, qlen, d_model]
        out = paddle.reshape(paddle.transpose(paddle.matmul(weights, v), [0, 2, 1, 3]), [0, 0, -1])
        out = self.out_proj(out)
        return out, weights
    
        
    def forward(self, query, key, value, attn_mask=0, mems=None):
        q = paddle.transpose(paddle.reshape(self.q_proj(query), [0, 0, self.n_head, self.d_head]), [0, 2, 1, 3])
        if mems is not None and mems.shape[1] > 0:
            key = paddle.concat([mems, key], 1)
            value = paddle.concat([mems, value], 1)
        k, v = self.compute_kv(key, value)
        out, weights = self.attention(q, k, v, attn_mask)
        return out
    
    def _parallelogram_mask(self, h, w, left=False):
        mask = paddle.ones((h, w))
        m = min(h, w)
        mask[:m,:m] = paddle.triu(mask[:m,:m])
        mask[-m:,-m:] = paddle.tril(mask[-m:,-m:])
        if left:
            return mask
        else:
            return mask.flip(0)
        
    def _shift(self, x, qlen, klen, mask, left=False):
        if qlen > 1:
            zero_pad = paddle.zeros((x.shape[0], qlen - 1, x.shape[2], x.shape[3]), dtype=x.dtype)
        else:
            zero_pad = paddle.zeros((x.shape[0], 0, x.shape[2], x.shape[3]), dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = paddle.concat([zero_pad, x], axis=1).expand([qlen, -1, -1, -1])
        else:
            x_padded = paddle.concat([x, zero_pad], axis=1).expand([qlen, -1, -1, -1])

        x = x_padded.masked_select(mask[:, :, None, None].astype('bool')).reshape([qlen, klen, x.shape[2], x.shape[3]])
        return x
    
    def _rel_shift(self, x, zero_triu=False):
        # x: [bsz, n_head, qlen, klen]. Pad one zero column, reinterpret the last
        # two axes as [klen + 1, qlen], and drop the pad row: this left-shifts
        # row i so that column j lines up with key position j (Transformer-XL).
        zero_pad = paddle.zeros((x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
        x_padded = paddle.concat([zero_pad, x], axis=-1)
        x_padded = x_padded.reshape([x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
        x = x_padded[:, :, 1:].reshape(x.shape)
        if zero_triu:
            ones = paddle.ones((x.shape[2], x.shape[3]))
            x = x * paddle.tril(ones, x.shape[3] - x.shape[2])[None, None, :, :]
        return x
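
# Worked sketch of _rel_shift (our illustration): for qlen = klen = 3 the map is
# out[i, j] = in[i, j + qlen - 1 - i], i.e. row i shifts left by qlen - 1 - i.
# Because pos_seq runs klen-1 .. 0, in[i, j] scored relative distance
# klen - 1 - j, so after the shift out[i, j] scores key position j at distance
# mlen + i - j. Slots shifted past the row boundary pick up wrapped values,
# which is harmless since they lie strictly above the causal mask boundary.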
    
class RelLearnableMultiHeadAttn(MultiHeadAttention):
    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
        # r_emb: [klen, n_head, d_head], used for term B
        # r_w_bias: [n_head, d_head], used for term C
        # r_bias: [klen, n_head], used for term D
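        # (illustrative sizes, our example: n_head=4, d_head=16, klen=32 gives
        #  r_emb [32, 4, 16], r_w_bias [4, 16], r_bias [32, 4])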

        qlen, bsz = w.shape[1], w.shape[0]

        if mems is not None and mems.shape[1] > 0:
            cat = paddle.concat([mems, w], 1)
            w_head_q = self.q_proj(cat[:, -qlen:])
            w_head_k = self.k_proj(cat)
            w_head_v = self.v_proj(cat)
        else:
            w_head_q = self.q_proj(w)
            w_head_k = self.k_proj(w)
            w_head_v = self.v_proj(w)

        klen = w_head_k.shape[1]

        # [bsz, len, d_model] -> [bsz, n_head, len, d_head]
        w_head_q = w_head_q.reshape((bsz, qlen, self.n_head, self.d_head)).transpose([0, 2, 1, 3])
        w_head_k = w_head_k.reshape((bsz, klen, self.n_head, self.d_head)).transpose([0, 2, 1, 3])
        w_head_v = w_head_v.reshape((bsz, klen, self.n_head, self.d_head)).transpose([0, 2, 1, 3])

        if klen > r_emb.shape[0]:
            r_emb_pad = r_emb[0:1].expand([klen - r_emb.shape[0], -1, -1])
            r_emb = paddle.concat([r_emb_pad, r_emb], 0)
            r_bias_pad = r_bias[0:1].expand([klen - r_bias.shape[0], -1])
            r_bias = paddle.concat([r_bias_pad, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[:, None]                               # [bsz, n_head, qlen, d_head]

        AC = paddle.einsum('bnid,bnjd->bnij', rw_head_q, w_head_k)             # [bsz, n_head, qlen, klen]
        B_ = paddle.einsum('bnid,jnd->bnij', w_head_q, r_emb)                  # [bsz, n_head, qlen, klen]
        D_ = r_bias.T[None, :, None]                                           # [1, n_head, 1, klen]
        BD = self._rel_shift(B_ + D_)

        # [bsz, n_head, qlen, klen]
        attn_score = AC + BD
        attn_score *= self.scale

        #### compute attention probability
        if attn_mask is not None and attn_mask.astype("bool").any():
            if attn_mask.dim() == 2:
                attn_score= paddle.where(attn_mask[:,None,None,:], paddle.zeros([1])-float('inf'),attn_score)
            elif attn_mask.dim() == 3:
                attn_score= paddle.where(attn_mask[:,None,:,:], paddle.zeros([1])-float('inf'),attn_score)

        # [bsz, n_head, qlen, klen]
        attn_prob = F.softmax(attn_score, axis=-1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = paddle.einsum('bnij,bnjd->bind', attn_prob, w_head_v)

        attn_vec = attn_vec.reshape([
            attn_vec.shape[0], attn_vec.shape[1], self.n_head * self.d_head])
        output = self.out_proj(attn_vec)
        return output
    
class RelPartialLearnableMultiHeadAttn(MultiHeadAttention):
    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
        self.r_net = nn.Linear(self.embed_dim, self.embed_dim, bias_attr=False)

    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
        qlen, rlen, bsz = w.shape[1], r.shape[0], w.shape[0]
        if mems is not None and mems.shape[1] > 0:
            cat = paddle.concat([mems, w], 1)
            w_head_q = self.q_proj(cat[:,-qlen:])
            w_head_k = self.k_proj(cat)
            w_head_v = self.v_proj(cat)
        else:
            w_head_q = self.q_proj(w)
            w_head_k = self.k_proj(w)
            w_head_v = self.v_proj(w)
        
        r_head_k = self.r_net(r)
        klen = w_head_k.shape[1]
        
        # [bsz, len, d_model] -> [bsz, n_head, len, d_head]
        w_head_q = w_head_q.reshape((bsz, qlen, self.n_head, self.d_head)).transpose([0, 2, 1, 3])
        w_head_k = w_head_k.reshape((bsz, klen, self.n_head, self.d_head)).transpose([0, 2, 1, 3])
        w_head_v = w_head_v.reshape((bsz, klen, self.n_head, self.d_head)).transpose([0, 2, 1, 3])

        r_head_k = r_head_k.reshape([rlen, self.n_head, self.d_head])

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias.unsqueeze(1)                           # [bsz, n_head, qlen, d_head]

        AC = paddle.einsum('bnid,bnjd->bnij', rw_head_q, w_head_k)             # [bsz, n_head, qlen, klen]

        rr_head_q = w_head_q + r_r_bias.unsqueeze(1)
        BD = paddle.einsum('bnid,jnd->bnij', rr_head_q, r_head_k)              # [bsz, n_head, qlen, klen]
        BD = self._rel_shift(BD)

        # [bsz, n_head, qlen, klen]
        attn_score = AC + BD
        attn_score *= self.scale

        #### compute attention probability
        if attn_mask is not None and attn_mask.astype("bool").any():
            if attn_mask.dim() == 2:
                attn_score= paddle.where(attn_mask[:,None,None,:], paddle.zeros([1])-float('inf'),attn_score)
            elif attn_mask.dim() == 3:
                attn_score= paddle.where(attn_mask[:,None,:,:], paddle.zeros([1])-float('inf'),attn_score)

        # [bsz, n_head, qlen, klen]
        attn_prob = F.softmax(attn_score, axis=-1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = paddle.einsum('bnij,bnjd->bind', attn_prob, w_head_v)

        attn_vec = attn_vec.reshape([
            attn_vec.shape[0], attn_vec.shape[1], self.n_head * self.d_head])

        attn_out = self.out_proj(attn_vec)
        return attn_out
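
# Shape sketch (our illustration): w [bsz, qlen, d_model]; r [klen, d_model]
# holds the sinusoidal embeddings of reversed positions klen-1 .. 0; r_w_bias
# and r_r_bias are [n_head, d_head]. With cached mems [bsz, mlen, d_model] the
# keys/values span klen = mlen + qlen and the output is [bsz, qlen, d_model].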
        
class RelLearnableDecoderLayer(nn.Layer):
    def __init__(self,d_model,n_head,dim_feedforward,dropout=0.1,dropatt=0.1,activation='GELU',normalize_before=False,**kwargs):
        self._config = locals()
        self._config.pop("__class__", None)
        super(RelLearnableDecoderLayer, self).__init__()
        self.normalize_before = normalize_before
        self.self_attn = RelLearnableMultiHeadAttn(d_model,n_head,dropatt=dropatt,**kwargs)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
        self.norm1 = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = getattr(nn, activation)()
        self.dropact = nn.Dropout(dropout, mode="upscale_in_train")
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
        self.norm2 = nn.LayerNorm(d_model)
        self._config.pop("self")

    def forward(self,dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
        if not self.normalize_before:
            residual = dec_inp
            tgt = self.self_attn(dec_inp,r_emb,r_w_bias,r_bias,
                                   attn_mask=dec_attn_mask,
                                   mems=mems)
            tgt = residual + self.dropout1(tgt)
            tgt = self.norm1(tgt)
            residual = tgt
            tgt = self.linear2(self.dropact(self.activation(self.linear1(tgt))))
            tgt = residual + self.dropout2(tgt)
            tgt = self.norm2(tgt)
            return tgt
        residual = dec_inp
        tgt = self.norm1(dec_inp)
        tgt = self.self_attn(tgt,r_emb,r_w_bias,r_bias,
                           attn_mask=dec_attn_mask,
                           mems=mems)
        tgt = residual + self.dropout1(tgt)
        residual = tgt
        tgt = self.norm2(tgt)
        tgt = self.linear2(self.dropact(self.activation(self.linear1(tgt))))
        tgt = residual + self.dropout2(tgt)
        return tgt

class RelPartialLearnableDecoderLayer(nn.Layer):
    def __init__(self,d_model,n_head,dim_feedforward,dropout=0.1,dropatt=0.1,activation='GELU',normalize_before=False,**kwargs):
        self._config = locals()
        self._config.pop("__class__", None)
        super(RelPartialLearnableDecoderLayer, self).__init__()
        self.normalize_before = normalize_before
        self.self_attn = RelPartialLearnableMultiHeadAttn(d_model,n_head,dropatt=dropatt,**kwargs)
        self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
        self.norm1 = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = getattr(nn, activation)()
        self.dropact = nn.Dropout(dropout, mode="upscale_in_train")
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout2 = nn.Dropout(dropout, mode="upscale_in_train")
        self.norm2 = nn.LayerNorm(d_model)
        self._config.pop("self")

    def forward(self,dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
        if not self.normalize_before:
            residual = dec_inp
            tgt = self.self_attn(dec_inp, r, r_w_bias, r_r_bias,
                                   attn_mask=dec_attn_mask,
                                   mems=mems)
            tgt = residual + self.dropout1(tgt)
            tgt = self.norm1(tgt)
            residual = tgt
            tgt = self.linear2(self.dropact(self.activation(self.linear1(tgt))))
            tgt = residual + self.dropout2(tgt)
            tgt = self.norm2(tgt)
            return tgt
        residual = dec_inp
        tgt = self.norm1(dec_inp)
        tgt = self.self_attn(tgt, r, r_w_bias, r_r_bias,
                           attn_mask=dec_attn_mask,
                           mems=mems)
        tgt = residual + self.dropout1(tgt)
        residual = tgt
        tgt = self.norm2(tgt)
        tgt = self.linear2(self.dropact(self.activation(self.linear1(tgt))))
        tgt = residual + self.dropout2(tgt)
        return tgt
        
class WordEmbedding(nn.Layer):      
    def __init__(self, n_token, d_embed,pad_id=0,dropout=0.1,**kwargs):
        super(WordEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.emb_scale = d_embed ** 0.5
        self.word_embeddings = nn.Embedding(n_token, d_embed, padding_idx=pad_id)
        self.layer_norm = nn.LayerNorm(d_embed)

    def forward(self, inp):
        embeddings = self.word_embeddings(inp)
        embeddings = self.layer_norm(embeddings)
        return embeddings
        
class ERNIE_XL(nn.Layer):
    def __init__(self, n_token, n_layer,n_head, d_model, d_inner,attn_type=0,
                 dropout=0.1, dropatt=0.1, normalize_before=False,
                 tgt_len=0, ext_len=0, mem_len=0,**kwargs):
        super(ERNIE_XL, self).__init__()
        self.n_token = n_token

        d_embed = d_model
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model//n_head

        self.word_emb = WordEmbedding(n_token, d_embed, dropout=dropout,**kwargs)
        
        
        self.drop = nn.Dropout(dropout)
        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len
        self.attn_type = attn_type
        self.layers = nn.LayerList()
        if attn_type == 0: # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        d_model,n_head,d_inner, dropout,dropatt=dropatt,
                        normalize_before=normalize_before,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len)
                )
        elif attn_type == 1: # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        d_model,n_head,d_inner,dropout,dropatt=dropatt,
                        normalize_before=normalize_before,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        
                    )
                )
        elif attn_type in [2, 3]: # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(  # NOTE: DecoderLayer is not defined in this file
                        d_model,n_head, d_inner, dropout,
                        dropatt=dropatt, normalize_before=normalize_before)
                )

#         self.sample_softmax = sample_softmax
#         # use sampled softmax
#         if sample_softmax > 0:
#             self.out_layer = nn.Linear(d_model, n_token)
#             if tie_weight:
#                 self.out_layer.weight = self.word_emb.weight
#             self.tie_weight = tie_weight
#             self.sampler = LogUniformSampler(n_token, sample_softmax)

#         # use adaptive softmax (including standard softmax)
#         else:
#             self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, 
#                                                     cutoffs, div_val=div_val)

#             if tie_weight:
#                 for i in range(len(self.crit.out_layers)):
#                     self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

#             if tie_projs:
#                 for i, tie_proj in enumerate(tie_projs):
#                     if tie_proj and div_val == 1 and d_model != d_embed:
#                         self.crit.out_projs[i] = self.word_emb.emb_projs[0]
#                     elif tie_proj and div_val != 1:
#                         self.crit.out_projs[i] = self.word_emb.emb_projs[i]

#         self.same_length = same_length
#         self.clamp_len = clamp_len

        self._create_params()

#     def backward_compatible(self):
#         self.sample_softmax = -1

    def _create_params(self):
        if self.attn_type == 0: # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            self.r_w_bias = paddle.create_parameter((self.n_head, self.d_head),"float32")
            self.r_r_bias = paddle.create_parameter((self.n_head, self.d_head),"float32")
        elif self.attn_type == 1: # learnable
            self.r_emb = paddle.create_parameter((
                    self.n_layer, self.max_klen, self.n_head, self.d_head), "float32")
            self.r_w_bias = paddle.create_parameter((
                    self.n_layer, self.n_head, self.d_head), "float32")
            self.r_bias = paddle.create_parameter((
                    self.n_layer, self.max_klen, self.n_head), "float32")
        elif self.attn_type == 2: # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3: # absolute deeper SA
            self.r_emb = paddle.create_parameter((
                    self.n_layer, self.max_klen, self.n_head, self.d_head), "float32")

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            for i in range(self.n_layer+1):
                empty = paddle.empty([0,0,0],"float32")
                mems.append(empty)
            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        # does not deal with None
        if mems is None: return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        new_mems = []
        end_idx = mlen + max(0, qlen - self.ext_len)
        beg_idx = max(0, end_idx - self.mem_len)
        for i in range(len(hids)):
            if mems[i].shape[1] == 0:
                cat = hids[i]
            else:
                cat = paddle.concat([mems[i], hids[i]], 1)
            # the memory is not back-propagated through
            new_mems.append(cat[:, beg_idx:end_idx].detach())

        return new_mems
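
    # Worked example (sketch): with mem_len=16, ext_len=0, qlen=16 and mlen=16,
    # end_idx = 16 + max(0, 16 - 0) = 32 and beg_idx = max(0, 32 - 16) = 16,
    # so the new memory is exactly the current segment's hidden states.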

    def _forward(self, dec_inp, mems=None):
        bsz, qlen = dec_inp.shape

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].shape[1] if mems is not None else 0
        klen = mlen + qlen
        
        hids = []
        # Causal mask over the extended context: token i of the current segment
        # may attend to the mlen cached steps and to positions <= i + mlen.
        dec_attn_mask = paddle.triu(paddle.ones([qlen, klen]), diagonal=1 + mlen).astype("bool")
        if self.attn_type == 0: # default
            pos_seq = paddle.arange(klen-1, -1, -1.0, dtype=word_emb.dtype)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, pos_emb, self.r_w_bias,
                        self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1: # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                r_emb, r_bias = self.r_emb[i], self.r_bias[i]
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, r_emb, self.r_w_bias[i],
                        r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2: # absolute
            pos_seq = paddle.arange(klen - 1, -1, -1.0, dtype=word_emb.dtype)

            pos_emb = self.pos_emb(pos_seq)
            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.shape[0]
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand([mlen-cur_size, -1, -1])
                        cur_emb = paddle.concat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.reshape([1, mlen, -1])
                core_out += self.r_emb[i][-qlen:].reshape([1, qlen, -1])
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, qlen, mlen)

        return core_out, new_mems

    def forward(self, inp, mems=None):
        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
        # So, have to initialize size(0) mems inside the model forward.
        # Moreover, have to return new_mems to allow nn.DataParallel to piece
        # them together.
        if not mems: mems = self.init_mems()
        # Process the input in segments of mem_len tokens; a trailing remainder
        # shorter than mem_len is dropped. Returns the hidden states of the
        # last segment.
        block_num = inp.shape[1] // self.mem_len
        for step in range(block_num):
            s = step * self.mem_len
            e = s + self.mem_len
            data = inp[:, s:e]
            hidden, mems = self._forward(data, mems=mems)
        return hidden
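
if __name__ == "__main__":
    # Minimal smoke test (our sketch; the sizes below are arbitrary): feed two
    # 16-token segments through a small model, so the second segment attends
    # to the first through the recurrence memory.
    paddle.seed(0)
    model = ERNIE_XL(n_token=1000, n_layer=2, n_head=4, d_model=64,
                     d_inner=128, mem_len=16)
    inp = paddle.randint(0, 1000, [2, 32])  # [bsz, 2 * mem_len]
    hidden = model(inp)
    print(hidden.shape)  # expect [2, 16, 64]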